{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WGAN Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from models.WGAN import WGAN\n",
    "from utils.loaders import load_cifar\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run parameters: identify this experiment and the folder its outputs go to.\n",
    "SECTION = 'gan'\n",
    "RUN_ID = '0002'\n",
    "DATA_NAME = 'horses'\n",
    "RUN_FOLDER = 'run/{}/'.format(SECTION)\n",
    "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n",
    "\n",
    "# makedirs(exist_ok=True) creates missing parent folders ('run/<SECTION>') and\n",
    "# fills in any sub-folder that is absent from an already existing run folder.\n",
    "for subfolder in ('viz', 'images', 'weights'):\n",
    "    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)\n",
    "\n",
    "# 'build' saves a freshly constructed model; any other value (e.g. 'load')\n",
    "# makes the architecture cell load weights from RUN_FOLDER instead.\n",
    "mode = 'build'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label handed to load_cifar for each supported DATA_NAME.\n",
    "CIFAR_LABELS = {'cars': 1, 'horses': 7}\n",
    "\n",
    "if DATA_NAME not in CIFAR_LABELS:\n",
    "    # Fail here with a clear message rather than a NameError on `label` below.\n",
    "    raise ValueError(\n",
    "        'Unknown DATA_NAME {!r}; expected one of {}'.format(DATA_NAME, sorted(CIFAR_LABELS))\n",
    "    )\n",
    "\n",
    "label = CIFAR_LABELS[DATA_NAME]\n",
    "(x_train, y_train) = load_cifar(label, 10)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show training sample 150, shifted and halved for display\n",
    "# (presumably load_cifar scales pixels to [-1, 1] - TODO confirm).\n",
    "plt.imshow(0.5 * (x_train[150] + 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the model in both modes: load_weights needs an existing `gan`\n",
    "# object, so building only under mode == 'build' would leave 'load' with no model.\n",
    "gan = WGAN(\n",
    "    input_dim=(32, 32, 3),\n",
    "    critic_conv_filters=[32, 64, 128, 128],\n",
    "    critic_conv_kernel_size=[5, 5, 5, 5],\n",
    "    critic_conv_strides=[2, 2, 2, 1],\n",
    "    critic_batch_norm_momentum=None,\n",
    "    critic_activation='leaky_relu',\n",
    "    critic_dropout_rate=None,\n",
    "    critic_learning_rate=0.00005,\n",
    "    generator_initial_dense_layer_size=(4, 4, 128),\n",
    "    generator_upsample=[2, 2, 2, 1],\n",
    "    generator_conv_filters=[128, 64, 32, 3],\n",
    "    generator_conv_kernel_size=[5, 5, 5, 5],\n",
    "    generator_conv_strides=[1, 1, 1, 1],\n",
    "    generator_batch_norm_momentum=0.8,\n",
    "    generator_activation='leaky_relu',\n",
    "    generator_dropout_rate=None,\n",
    "    generator_learning_rate=0.00005,\n",
    "    optimiser='rmsprop',\n",
    "    z_dim=100,\n",
    ")\n",
    "\n",
    "if mode == 'build':\n",
    "    # Fresh run: save the newly constructed model into the run folder.\n",
    "    gan.save(RUN_FOLDER)\n",
    "else:\n",
    "    # Any other mode: load weights from RUN_FOLDER/weights/weights.h5.\n",
    "    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gan.critic.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gan.generator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "EPOCHS = 6000\n",
    "PRINT_EVERY_N_BATCHES = 5\n",
    "N_CRITIC = 5\n",
    "CLIP_THRESHOLD = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train on x_train using the hyperparameter constants defined in the previous cell.\n",
    "gan.train(\n",
    "    x_train,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    run_folder=RUN_FOLDER,\n",
    "    print_every_n_batches=PRINT_EVERY_N_BATCHES,\n",
    "    n_critic=N_CRITIC,\n",
    "    clip_threshold=CLIP_THRESHOLD,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gan.sample_images(RUN_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the recorded loss histories: the first three elements of every\n",
    "# gan.d_losses entry (black, green, red) and gan.g_losses (orange).\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "critic_colours = ['black', 'green', 'red']\n",
    "for component, colour in enumerate(critic_colours):\n",
    "    ax.plot([entry[component] for entry in gan.d_losses], color=colour, linewidth=0.25)\n",
    "ax.plot(gan.g_losses, color='orange', linewidth=0.25)\n",
    "\n",
    "ax.set_xlabel('batch', fontsize=18)\n",
    "ax.set_ylabel('loss', fontsize=16)\n",
    "\n",
    "# ax.set_xlim(0, 2000)\n",
    "# ax.set_ylim(0, 2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_images(img1, img2):\n",
    "    \"\"\"Return the mean absolute element-wise difference between two images.\n",
    "\n",
    "    Both inputs are converted to float64 before subtracting, so unsigned\n",
    "    integer images (e.g. uint8) do not wrap around and plain nested lists\n",
    "    are accepted as well.\n",
    "\n",
    "    Args:\n",
    "        img1: array-like image.\n",
    "        img2: array-like image whose shape broadcasts against ``img1``.\n",
    "\n",
    "    Returns:\n",
    "        The mean of ``|img1 - img2|`` as a NumPy float64 scalar.\n",
    "    \"\"\"\n",
    "    first = np.asarray(img1, dtype=np.float64)\n",
    "    second = np.asarray(img2, dtype=np.float64)\n",
    "    return np.mean(np.abs(first - second))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save an r x c grid of randomly chosen training images to images/real.png.\n",
    "r, c = 5, 5\n",
    "\n",
    "# One random index per grid cell; the grid size is independent of BATCH_SIZE.\n",
    "idx = np.random.randint(0, x_train.shape[0], r * c)\n",
    "true_imgs = (x_train[idx] + 1) * 0.5\n",
    "\n",
    "fig, axs = plt.subplots(r, c, figsize=(15, 15))\n",
    "for ax, img in zip(axs.flat, true_imgs):\n",
    "    ax.imshow(img, cmap='gray_r')\n",
    "    ax.axis('off')\n",
    "\n",
    "fig.savefig(os.path.join(RUN_FOLDER, \"images/real.png\"))\n",
    "# Close this specific figure so it is not also rendered inline.\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r, c = 5, 5\n",
    "noise = np.random.normal(0, 1, (r * c, gan.z_dim))\n",
    "gen_imgs = gan.generator.predict(noise)\n",
    "\n",
    "# Rescale generated images with the same (x + 1) / 2 mapping applied to x_train below.\n",
    "gen_imgs = 0.5 * (gen_imgs + 1)\n",
    "# gen_imgs = np.clip(gen_imgs, 0, 1)\n",
    "\n",
    "# Grid 1: the generated samples themselves.\n",
    "fig, axs = plt.subplots(r, c, figsize=(15, 15))\n",
    "for ax, gen_img in zip(axs.flat, gen_imgs):\n",
    "    ax.imshow(np.squeeze(gen_img), cmap='gray_r')\n",
    "    ax.axis('off')\n",
    "fig.savefig(os.path.join(RUN_FOLDER, \"images/sample.png\"))\n",
    "plt.close(fig)\n",
    "\n",
    "# Grid 2: for each generated sample, the training image with the smallest\n",
    "# compare_images distance. The rescaled training set does not change between\n",
    "# samples, so it is computed once up front.\n",
    "x_train_scaled = (x_train + 1) * 0.5\n",
    "\n",
    "fig, axs = plt.subplots(r, c, figsize=(15, 15))\n",
    "for ax, gen_img in zip(axs.flat, gen_imgs):\n",
    "    diffs = [compare_images(gen_img, real_img) for real_img in x_train_scaled]\n",
    "    # argmin returns the first minimum, so ties go to the earliest training image.\n",
    "    closest_img = x_train_scaled[int(np.argmin(diffs))]\n",
    "    ax.imshow(closest_img, cmap='gray_r')\n",
    "    ax.axis('off')\n",
    "\n",
    "fig.savefig(os.path.join(RUN_FOLDER, \"images/sample_closest.png\"))\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl_code",
   "language": "python",
   "name": "gdl_code"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
